""" DaViT: Dual Attention Vision Transformers

As described in https://arxiv.org/abs/2204.03645

Input size invariant transformer architecture that combines channel and spatial
attention in each block. The attention mechanisms used are linear in complexity.

DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below
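
A minimal usage sketch (assumes `timm` is installed; `davit_tiny` is registered
below, pretrained weight availability depends on the hub):

    import torch, timm
    model = timm.create_model('davit_tiny', pretrained=False)
    logits = model(torch.randn(1, 3, 224, 224))  # -> (1, 1000)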

"""
# Copyright (c) 2022 Mingyu Ding
# All rights reserved.
# This source code is licensed under the MIT license
from collections import OrderedDict
from functools import partial
from typing import List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, calculate_drop_path_rates, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint, checkpoint_seq
from ._registry import generate_default_cfgs, register_model

__all__ = ['DaVit']


class ConvPosEnc(nn.Module):
    def __init__(
            self,
            dim: int,
            k: int = 3,
            act: bool = False,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()

        self.proj = nn.Conv2d(
            dim,
            dim,
            kernel_size=k,
            stride=1,
            padding=k // 2,
            groups=dim,
            **dd,
        )
        self.act = nn.GELU() if act else nn.Identity()

    def forward(self, x: Tensor):
        feat = self.proj(x)
        x = x + self.act(feat)
        return x


class Stem(nn.Module):
    """ Size-agnostic implementation of 2D image to patch embedding,
        allowing the input size to vary between model forward passes.
    """

    def __init__(
            self,
            in_chs: int = 3,
            out_chs: int = 96,
            stride: int = 4,
            norm_layer: Type[nn.Module] = LayerNorm2d,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        stride = to_2tuple(stride)
        self.stride = stride
        self.in_chs = in_chs
        self.out_chs = out_chs
        assert stride[0] == 4  # only setup for stride==4
        self.conv = nn.Conv2d(
            in_chs,
            out_chs,
            kernel_size=7,
            stride=stride,
            padding=3,
            **dd,
        )
        self.norm = norm_layer(out_chs, **dd)

    def forward(self, x: Tensor):
        B, C, H, W = x.shape
        pad_r = (self.stride[1] - W % self.stride[1]) % self.stride[1]
        pad_b = (self.stride[0] - H % self.stride[0]) % self.stride[0]
        x = F.pad(x, (0, pad_r, 0, pad_b))
        x = self.conv(x)
        x = self.norm(x)
        return x
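
# Stem padding example (illustrative): a 225x225 input with stride 4 pads to
# 228x228 (pad = (4 - 225 % 4) % 4 = 3), so the 7x7/stride-4 conv above yields
# a 57x57 feature map; any input size is handled the same way.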


class Downsample(nn.Module):
    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            norm_layer: Type[nn.Module] = LayerNorm2d,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_chs = in_chs
        self.out_chs = out_chs

        self.norm = norm_layer(in_chs, **dd)
        self.even_k = kernel_size % 2 == 0
        self.conv = nn.Conv2d(
            in_chs,
            out_chs,
            kernel_size=kernel_size,
            stride=2,
            padding=0 if self.even_k else kernel_size // 2,
            **dd,
        )

    def forward(self, x: Tensor):
        B, C, H, W = x.shape
        x = self.norm(x)
        if self.even_k:
            k_h, k_w = self.conv.kernel_size
            pad_r = (k_w - W % k_w) % k_w
            pad_b = (k_h - H % k_h) % k_h
            x = F.pad(x, (0, pad_r, 0, pad_b))
        x = self.conv(x)
        return x
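
# Downsample padding example (illustrative): with the default
# down_kernel_size=2 (even), a 57x57 map pads up to 58x58 and the stride-2
# conv yields 29x29; the odd kernels used by the _fl variants rely on
# symmetric padding of kernel_size // 2 instead.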


class ChannelAttentionV2(nn.Module):

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = True,
            dynamic_scale: bool = True,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.groups = num_heads
        self.head_dim = dim // num_heads
        self.dynamic_scale = dynamic_scale

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
        self.proj = nn.Linear(dim, dim, **dd)

    def forward(self, x: Tensor):
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.dynamic_scale:
            # v2 'dynamic' scaling normalizes q by token count N rather than head_dim
            q = q * N ** -0.5
        else:
            q = q * self.head_dim ** -0.5
        attn = q.transpose(-1, -2) @ k
        attn = attn.softmax(dim=-1)
        x = (attn @ v.transpose(-1, -2)).transpose(-1, -2)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x


class ChannelAttention(nn.Module):

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
        self.proj = nn.Linear(dim, dim, **dd)

    def forward(self, x: Tensor):
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        k = k * self.scale
        attn = k.transpose(-1, -2) @ v
        attn = attn.softmax(dim=-1)
        x = (attn @ q.transpose(-1, -2)).transpose(-1, -2)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x
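
# Channel attention shape sketch (illustrative values): at stage 0 of
# davit_tiny with a 224x224 input, N = 56*56 = 3136, C = 96, num_heads = 3,
# so q/k/v are (B, 3, 3136, 32) and k^T @ v is (B, 3, 32, 32). The attention
# map is head_dim x head_dim regardless of N, hence the linear-in-token-count
# complexity noted in the module docstring.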


class ChannelBlock(nn.Module):

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            ffn: bool = True,
            cpe_act: bool = False,
            v2: bool = False,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()

        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd)
        self.ffn = ffn
        self.norm1 = norm_layer(dim, **dd)
        attn_layer = ChannelAttentionV2 if v2 else ChannelAttention
        self.attn = attn_layer(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            **dd,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd)

        if self.ffn:
            self.norm2 = norm_layer(dim, **dd)
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=int(dim * mlp_ratio),
                act_layer=act_layer,
                **dd,
            )
            self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        else:
            self.norm2 = None
            self.mlp = None
            self.drop_path2 = None

    def forward(self, x: Tensor):
        B, C, H, W = x.shape

        x = self.cpe1(x).flatten(2).transpose(1, 2)

        cur = self.norm1(x)
        cur = self.attn(cur)
        x = x + self.drop_path1(cur)

        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))

        if self.mlp is not None:
            x = x.flatten(2).transpose(1, 2)
            x = x + self.drop_path2(self.mlp(self.norm2(x)))
            x = x.transpose(1, 2).view(B, C, H, W)

        return x


def window_partition(x: Tensor, window_size: Tuple[int, int]):
    """
    Args:
        x: (B, H, W, C)
        window_size (Tuple[int, int]): window size (height, width)
    Returns:
        windows: (num_windows*B, window_size[0], window_size[1], C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
    return windows


@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int):
    """
    Args:
        windows: (num_windows*B, window_size[0], window_size[1], C)
        window_size (Tuple[int, int]): window size (height, width)
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x
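
# Round-trip example for the two helpers above (illustrative): x of shape
# (2, 14, 14, 96) with window_size=(7, 7) partitions into (2*2*2, 7, 7, 96);
# window_reverse(windows, (7, 7), 14, 14) restores the (2, 14, 14, 96) layout.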


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            window_size: Tuple[int, int],
            num_heads: int,
            qkv_bias: bool = True,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
        self.proj = nn.Linear(dim, dim, **dd)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: Tensor):
        B_, N, C = x.shape

        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))
            attn = self.softmax(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x


class SpatialBlock(nn.Module):
    r""" Spatial window attention block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            window_size: int = 7,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            drop_path: float = 0.,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            ffn: bool = True,
            cpe_act: bool = False,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.dim = dim
        self.ffn = ffn
        self.num_heads = num_heads
        self.window_size = to_2tuple(window_size)
        self.mlp_ratio = mlp_ratio

        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd)
        self.norm1 = norm_layer(dim, **dd)
        self.attn = WindowAttention(
            dim,
            self.window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            **dd,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act, **dd)
        if self.ffn:
            self.norm2 = norm_layer(dim, **dd)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                **dd,
            )
            self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        else:
            self.norm2 = None
            self.mlp = None
            self.drop_path2 = None

    def forward(self, x: Tensor):
        B, C, H, W = x.shape

        shortcut = self.cpe1(x).flatten(2).transpose(1, 2)

        x = self.norm1(shortcut)
        x = x.view(B, H, W, C)

        pad_l = pad_t = 0
        pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
        pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        x_windows = window_partition(x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)

        # W-MSA (no shifted windows in DaViT)
        attn_windows = self.attn(x_windows)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)

        # remove any padding added for window partitioning
        x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path1(x)

        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))

        if self.mlp is not None:
            x = x.flatten(2).transpose(1, 2)
            x = x + self.drop_path2(self.mlp(self.norm2(x)))
            x = x.transpose(1, 2).view(B, C, H, W)

        return x


class DaVitStage(nn.Module):
    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            depth: int = 1,
            downsample: bool = True,
            attn_types: Tuple[str, ...] = ('spatial', 'channel'),
            num_heads: int = 3,
            window_size: int = 7,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            drop_path_rates: Tuple[float, ...] = (0, 0),
            norm_layer: Type[nn.Module] = LayerNorm2d,
            norm_layer_cl: Type[nn.Module] = nn.LayerNorm,
            ffn: bool = True,
            cpe_act: bool = False,
            down_kernel_size: int = 2,
            named_blocks: bool = False,
            channel_attn_v2: bool = False,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()

        self.grad_checkpointing = False

        # downsample embedding layer at the beginning of each stage
        if downsample:
            self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer, **dd)
        else:
            self.downsample = nn.Identity()

        # Repeat alternating attention blocks in each stage,
        # default: (spatial -> channel) x depth.
        # NOTE: potential opportunity to integrate with a more general version of
        # ByobNet/ByoaNet since the logic is similar.
        stage_blocks = []
        for block_idx in range(depth):
            dual_attention_block = []
            for attn_idx, attn_type in enumerate(attn_types):
                if attn_type == 'spatial':
                    dual_attention_block.append(('spatial_block', SpatialBlock(
                        dim=out_chs,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=drop_path_rates[block_idx],
                        norm_layer=norm_layer_cl,
                        ffn=ffn,
                        cpe_act=cpe_act,
                        window_size=window_size,
                        **dd,
                    )))
                elif attn_type == 'channel':
                    dual_attention_block.append(('channel_block', ChannelBlock(
                        dim=out_chs,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        drop_path=drop_path_rates[block_idx],
                        norm_layer=norm_layer_cl,
                        ffn=ffn,
                        cpe_act=cpe_act,
                        v2=channel_attn_v2,
                        **dd,
                    )))
            if named_blocks:
                stage_blocks.append(nn.Sequential(OrderedDict(dual_attention_block)))
            else:
                stage_blocks.append(nn.Sequential(*[b[1] for b in dual_attention_block]))
        self.blocks = nn.Sequential(*stage_blocks)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward(self, x: Tensor):
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x


class DaVit(nn.Module):
    r""" DaViT
        A PyTorch implementation of `DaViT: Dual Attention Vision Transformers`  - https://arxiv.org/abs/2204.03645
        Supports arbitrary input sizes and pyramid feature extraction

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
        embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768)
        num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
    """

    def __init__(
            self,
            in_chans: int = 3,
            depths: Tuple[int, ...] = (1, 1, 3, 1),
            embed_dims: Tuple[int, ...] = (96, 192, 384, 768),
            num_heads: Tuple[int, ...] = (3, 6, 12, 24),
            window_size: int = 7,
            mlp_ratio: float = 4,
            qkv_bias: bool = True,
            norm_layer: str = 'layernorm2d',
            norm_layer_cl: str = 'layernorm',
            norm_eps: float = 1e-5,
            attn_types: Tuple[str, ...] = ('spatial', 'channel'),
            ffn: bool = True,
            cpe_act: bool = False,
            down_kernel_size: int = 2,
            channel_attn_v2: bool = False,
            named_blocks: bool = False,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            head_norm_first: bool = False,
            device=None,
            dtype=None,
    ):
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        num_stages = len(embed_dims)
        assert num_stages == len(num_heads) == len(depths)
        norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
        norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
        self.num_classes = num_classes
        self.num_features = self.head_hidden_size = embed_dims[-1]
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer, **dd)
        in_chs = embed_dims[0]

        dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
        stages = []
        for i in range(num_stages):
            out_chs = embed_dims[i]
            stage = DaVitStage(
                in_chs,
                out_chs,
                depth=depths[i],
                downsample=i > 0,
                attn_types=attn_types,
                num_heads=num_heads[i],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path_rates=dpr[i],
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl,
                ffn=ffn,
                cpe_act=cpe_act,
                down_kernel_size=down_kernel_size,
                channel_attn_v2=channel_attn_v2,
                named_blocks=named_blocks,
                **dd,
            )
            in_chs = out_chs
            stages.append(stage)
            self.feature_info += [dict(num_chs=out_chs, reduction=2**(i+2), module=f'stages.{i}')]

        self.stages = nn.Sequential(*stages)

        # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
        # otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt
        # FIXME generalize this structure to ClassifierHead
        if head_norm_first:
            self.norm_pre = norm_layer(self.num_features, **dd)
            self.head = ClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                **dd,
            )
        else:
            self.norm_pre = nn.Identity()
            self.head = NormMlpClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                norm_layer=norm_layer,
                **dd,
            )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^stem',  # stem and embed
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.downsample', (0,)),
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^norm_pre', (99999,)),
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable
        for stage in self.stages:
            stage.set_grad_checkpointing(enable=enable)

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:
            List of intermediate feature tensors if intermediates_only is True,
            otherwise a tuple of (final_features, intermediates).
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)
        last_idx = len(self.stages) - 1
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(stage, x)
            else:
                x = stage(x)
            if feat_idx in take_indices:
                if norm and feat_idx == last_idx:
                    x_inter = self.norm_pre(x)  # applying final norm to last intermediate
                else:
                    x_inter = x
                intermediates.append(x_inter)

        if intermediates_only:
            return intermediates

        if feat_idx == last_idx:
            x = self.norm_pre(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_norm:
            self.norm_pre = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        x = self.norm_pre(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=True) if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
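
# Feature pyramid sketch (shapes assume davit_tiny defaults and a 224x224 input):
#   model = davit_tiny()
#   feats = model.forward_intermediates(torch.randn(1, 3, 224, 224), intermediates_only=True)
#   # [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]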


def _convert_florence2(state_dict, model, prefix='vision_tower.'):
    import re
    out_dict = {}

    for k, v in state_dict.items():
        if k.startswith(prefix):
            k = k.replace(prefix, '')
        else:
            continue
        k = re.sub(r'convs.([0-9]+)', r'stages.\1.downsample', k)
        k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
        k = k.replace('downsample.proj', 'downsample.conv')
        k = k.replace('stages.0.downsample', 'stem')
        # NOTE: Florence-2 towers are headless (num_classes=0), head/norm keys are not remapped
        k = k.replace('window_attn.norm.', 'norm1.')
        k = k.replace('window_attn.fn.', 'attn.')
        k = k.replace('channel_attn.norm.', 'norm1.')
        k = k.replace('channel_attn.fn.', 'attn.')
        k = k.replace('ffn.norm.', 'norm2.')
        k = k.replace('ffn.fn.net.', 'mlp.')
        k = k.replace('conv1.fn.dw', 'cpe1.proj')
        k = k.replace('conv2.fn.dw', 'cpe2.proj')
        out_dict[k] = v

    return out_dict


def checkpoint_filter_fn(state_dict, model):
    """ Remap MSFT checkpoints -> timm """
    if 'head.fc.weight' in state_dict:
        return state_dict  # non-MSFT checkpoint

    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    if 'vision_tower.convs.0.proj.weight' in state_dict:
        return _convert_florence2(state_dict, model)

    import re
    out_dict = {}
    for k, v in state_dict.items():
        k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.downsample', k)
        k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
        k = k.replace('downsample.proj', 'downsample.conv')
        k = k.replace('stages.0.downsample', 'stem')
        k = k.replace('head.', 'head.fc.')
        k = k.replace('norms.', 'head.norm.')
        k = k.replace('cpe.0', 'cpe1')
        k = k.replace('cpe.1', 'cpe2')
        out_dict[k] = v
    return out_dict
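
# Example key remaps performed by checkpoint_filter_fn (illustrative keys):
#   'patch_embeds.0.proj.weight'  -> 'stem.conv.weight'
#   'patch_embeds.1.proj.weight'  -> 'stages.1.downsample.conv.weight'
#   'main_blocks.0.0.0.cpe.0.proj.weight' -> 'stages.0.blocks.0.0.cpe1.proj.weight'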


def _create_davit(variant, pretrained=False, **kwargs):
    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
    out_indices = kwargs.pop('out_indices', default_out_indices)

    strict = kwargs.pop('pretrained_strict', True)
    if variant.endswith('_fl'):
        # FIXME cleaner approach to missing head norm?
        strict = False

    model = build_model_with_cfg(
        DaVit,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        pretrained_strict=strict,
        **kwargs)

    return model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.95, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
        'license': 'apache-2.0',
        **kwargs
    }


# TODO contact authors to get larger pretrained models
default_cfgs = generate_default_cfgs({
    # official microsoft weights from https://github.com/dingmyu/davit
    'davit_tiny.msft_in1k': _cfg(
        hf_hub_id='timm/'),
    'davit_small.msft_in1k': _cfg(
        hf_hub_id='timm/'),
    'davit_base.msft_in1k': _cfg(
        hf_hub_id='timm/'),
    'davit_large': _cfg(),
    'davit_huge': _cfg(),
    'davit_giant': _cfg(),
    'davit_base_fl.msft_florence2': _cfg(
        hf_hub_id='microsoft/Florence-2-base',
        num_classes=0, input_size=(3, 768, 768)),
    'davit_huge_fl.msft_florence2': _cfg(
        hf_hub_id='microsoft/Florence-2-large',
        num_classes=0, input_size=(3, 768, 768)),
})


@register_model
def davit_tiny(pretrained=False, **kwargs) -> DaVit:
    model_args = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24))
    return _create_davit('davit_tiny', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_small(pretrained=False, **kwargs) -> DaVit:
    model_args = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24))
    return _create_davit('davit_small', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_base(pretrained=False, **kwargs) -> DaVit:
    model_args = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32))
    return _create_davit('davit_base', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_large(pretrained=False, **kwargs) -> DaVit:
    model_args = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48))
    return _create_davit('davit_large', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_huge(pretrained=False, **kwargs) -> DaVit:
    model_args = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64))
    return _create_davit('davit_huge', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_giant(pretrained=False, **kwargs) -> DaVit:
    model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96))
    return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_base_fl(pretrained=False, **kwargs) -> DaVit:
    model_args = dict(
        depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32),
        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
    )
    return _create_davit('davit_base_fl', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_huge_fl(pretrained=False, **kwargs) -> DaVit:
    # NOTE: huge image tower used in 'large' Florence2 model
    model_args = dict(
        depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64),
        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
    )
    return _create_davit('davit_huge_fl', pretrained=pretrained, **dict(model_args, **kwargs))
